

%% ========================================================================
%% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %%
%% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %%
%% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %% BASIC SETUP %%
%% ========================================================================
%% directories
% Start from an empty workspace. The function, data and save directories
% are all set to the current working folder.
clear
here    = pwd;
fundir  = here;  %[maindir 'TMS_code/'];
datadir = here;  %[maindir 'TMS_code/'];
savedir = here;  %[maindir];
addpath(fundir);
cd(fundir);
defaultPlotParameters

%% load data
% The task data are read from a CSV file located in datadir.
csvFile = sprintf('%s/TMS_horizonTask.csv', datadir);
sub = load_TMS_v1(csvFile);

% remove bad subjects
[sub, sub_bad] = removeBadSubjects_TMS_v1(sub);



%% ========================================================================
%% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %%
%% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %%
%% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %% MODEL-FREE ANALYSIS %%
%% ========================================================================

%% Figure 2: The reward-information confound 
% For every subject, stimulation condition and column 4..10 of the
% per-game matrices, correlate (across games) the sign of n2 - n1 with the
% sign of o2 - o1, separately for gameLength 5 ("horizon 1") and
% gameLength 10 ("horizon 6").
%
% rr1 / rr6 are [column index x stimulation condition x subject];
% stimulation condition 1 = vertex, 2 = RFPC.
clear rr1 rr6
stimNames = {'vertex' 'RFPC'};
trialCols = 4:10;
rr1 = nan(length(trialCols), numel(stimNames), length(sub));
rr6 = nan(size(rr1));
for sn = 1:length(sub)
    
    dn = sign(sub(sn).n2 - sub(sn).n1);
    dm = sign(sub(sn).o2 - sub(sn).o1);
    gl = sub(sn).gameLength;
    i1 = gl == 5;
    i6 = gl == 10;
    
    for c = 1:numel(stimNames)
        ic = strcmp(sub(sn).expt_name, stimNames{c});
        if sum(ic) == 0
            % subject has no games in this stimulation condition: keep
            % the NaN placeholders (applies to vertex AND RFPC)
            continue
        end
        for k = 1:length(trialCols)
            i = trialCols(k);
            rr1(k,c,sn) = corr(dn(i1&ic,i), dm(i1&ic,i));
            rr6(k,c,sn) = corr(dn(i6&ic,i), dm(i6&ic,i));
        end
    end
    
end

% NOTE(review): row 2 of the horizon-1 array and the last row of the
% horizon-6 array are blanked before plotting - presumably because those
% entries are not meaningful for the respective game lengths; confirm.
rr1(2,:,:) = nan;
rr6(end,:,:) = nan;

figure(1); clf;
set(gcf, 'position', [588   576   400   250]);

ax = easy_gridOfEqualFigures([0.2 0.1], [0.2 0.03]);
axes(ax(1)); hold on;
plot([0 7], [0 0], 'k--', 'linewidth', 1)

% across-subject mean and s.e.m.; the s.e.m. uses the number of subjects
% that actually contribute a (non-NaN) value. S1 / S6 are not plotted.
M1 = nanmean(rr1,3);
S1 = nanstd(rr1,[],3)./sqrt(sum(~isnan(rr1),3));
M6 = nanmean(rr6,3);
S6 = nanstd(rr6,[],3)./sqrt(sum(~isnan(rr6),3));

% after the transpose: rows of e = horizon (1, 6), columns = stimulation
% condition (vertex, RFPC)
e = plot(M1);
e(:,2) = plot(M6);
e = e';
set(e(1,:), 'color', AZblue);
set(e(2,:), 'color', AZred);
set(e(:,1), 'linestyle', '-','marker', '+')
set(e(:,2), 'linestyle', '--','marker', 'x')
set(e, 'markersize', 10, 'linewidth', 1)

set(ax, 'xtick', [1:6], 'xlim', [0.5 6.5], 'tickdir', 'out')
xlabel('free-trial number')
ylabel({'correlation between' 'reward and information'})
% linear indices [1 3 2 4] of the 2x2 handle array give the order
% vertex/h1, RFPC/h1, vertex/h6, RFPC/h6
legend(e([1 3 2 4]), ...
    {'vertex, horizon 1' 'RFPC, horizon 1', 'vertex, horizon 6' 'RFPC, horizon 6'}, ...
    'location', 'southeast')

%% Figure 3: Model-free analysis of the first free-choice trial 
% Per subject, game length and stimulation condition, average the
% column-5 entries of sub.hi / sub.lm / sub.rep / (sub.a == 2):
%   p_highInfo - all games of that game length and condition
%   p_repeat13 - games where n1 and n2 differ in column 4
%   p_lowMean, p_repeat, p_right - games where n1 and n2 are equal in column 4
% All arrays are [game length x subject x stimulation condition].

% vertex stimulation is 1, RFPC stimulation is 2
clear p_highInfo p_lowMean p_repeat p_repeat13 p_right
GL = [5 10];
stimNames = {'vertex' 'RFPC'};
nSub = length(sub);
p_highInfo = nan(length(GL), nSub, numel(stimNames));
p_repeat13 = nan(size(p_highInfo));
p_lowMean  = nan(size(p_highInfo));
p_repeat   = nan(size(p_highInfo));
p_right    = nan(size(p_highInfo));
for sn = 1:nSub
    equalInfo = sub(sn).n2(:,4) == sub(sn).n1(:,4);
    for c = 1:numel(stimNames)
        inStim = strcmp(sub(sn).expt_name, stimNames{c});
        for h = 1:length(GL)
            inHorizon = (sub(sn).gameLength == GL(h)) & inStim;
            
            p_highInfo(h,sn,c) = nanmean(sub(sn).hi(inHorizon,5));
            
            ind = inHorizon & ~equalInfo;
            p_repeat13(h,sn,c) = nanmean(sub(sn).rep(ind,5));
            
            ind = inHorizon & equalInfo;
            p_lowMean(h,sn,c) = nanmean(sub(sn).lm(ind,5));
            p_repeat(h,sn,c)  = nanmean(sub(sn).rep(ind,5));
            p_right(h,sn,c)   = nanmean(sub(sn).a(ind,5)==2);
        end
    end
end

% Collect the measures as columns of X (one row per subject). Column order
% within each measure: vertex/h1, RFPC/h1, vertex/h6, RFPC/h6.
%   1-4  p(high info),  5-8  p(low mean),  9-12  p(right)
clear X l l2 t m s e tt ll leg
X = nan(nSub, 12);
k = 0;
for P = {p_highInfo, p_lowMean, p_right}
    for h = 1:2
        for c = 1:2
            k = k + 1;
            X(:,k) = P{1}(h,:,c);
        end
    end
end

figure(1); clf;
set(gcf, 'position', [211   137   600   250])
ax = easy_gridOfEqualFigures([0.2  0.2], [0.14 0.17 0.05]);

clear vName
vName = {'vertex' 'RFPC' 'vertex' 'RFPC' 'vertex' 'RFPC'};
xx = 1:size(X,2);

% standard error of the mean over the subjects that have a value
sem = @(v) nanstd(v) / sqrt(sum(~isnan(v)));

yLabels = {'p(high info)' 'p(low mean)'};
titles  = {'directed exploration' 'random exploration'};

set(ax, 'xlim', [0.5 2.5])
for pnl = 1:2
    axes(ax(pnl)); hold on;
    
    % m, s: rows = stimulation condition, columns = horizon
    m = nan(2,2);
    s = nan(2,2);
    for h = 1:2
        for c = 1:2
            v = X(:, (pnl-1)*4 + (h-1)*2 + c);
            m(c,h) = nanmean(v);
            s(c,h) = sem(v);
        end
    end
    e(pnl,:) = errorbar(m, s);
    
    % NOTE: the significance markers are hard-coded strings; they are not
    % computed from the data in this script.
    if pnl == 1
        tt(1) = text(1.5, 0.565, '*');
    end
    tt(2*pnl)   = text(0.65, (m(1,1)+m(1,2))/2+0.01, '**');
    tt(2*pnl+1) = text(2.3, (m(2,1)+m(2,2))/2+0.004, '*');
    ll(2*pnl-1) = plot([0.84 0.72 0.72 0.84]+0.05, [m(1,1) m(1,1) m(1,2) m(1,2)]);
    ll(2*pnl)   = plot([2.2 2.32 2.32 2.2]-0.1, [m(2,1) m(2,1) m(2,2) m(2,2)]);
    
    ylabel(yLabels{pnl})
    xlabel('stimulation condition')
    t(pnl) = text(0, 0, titles{pnl});
    
    if pnl == 1
        % created after the bracket lines so that an auto-updating legend
        % does not pick them up as extra entries
        leg = legend(e(1,[2 1]), {'horizon 6' 'horizon 1'}, ...
            'orientation', 'vertical', 'location', 'south');
    end
end

set(t, 'fontsize', 18, 'units', 'normalized', 'position', [0.5 1.15], ...
    'horizontalAlignment', 'center', 'fontweight', 'bold')
set(ax(1), 'xticklabel', {vName{1:2}} ,'ylim', [0.4 0.6])
set(ax(2), 'xticklabel', {vName{3:4}},'ylim', [0.1 0.301])
set(ax, 'view', [0 90], 'xtick', xx, ...
    'tickdir', 'out', 'ytick', [0:0.1:1]);
set(e, 'marker', '.', 'markersize', 30, 'linestyle', '-')
set(e(:,1), 'color', AZblue)
set(e(:,2), 'color', AZred)

set(tt, 'horizontalalignment', 'center', 'fontsize', 24, ...
    'verticalalignment', 'top')
set(ll, 'color', 'k', 'linewidth', 1)

addABCs(ax, [-0.09 0.07], 32)

%% Figure 6: Model-free analysis of all trials 
% Per subject, game length, stimulation condition and free trial 1..6
% (columns 5..10), average sub.hi, sub.lm and sub.rep over games.
% Arrays are [game length x trial x stimulation condition x subject];
% stimulation condition 1 = vertex, 2 = RFPC.
clear p_highInfo p_lowMean p_repeat p_repeat13 p_repeat22
GL = [5 10];
stimNames = {'vertex' 'RFPC'};
nTrial = 6;
nSub = length(sub);
p_highInfo = nan(length(GL), nTrial, numel(stimNames), nSub);
p_lowMean  = nan(size(p_highInfo));
p_repeat   = nan(size(p_highInfo));
for sn = 1:nSub
    for c = 1:numel(stimNames)
        inStim = strcmp(sub(sn).expt_name, stimNames{c});
        for h = 1:length(GL)
            ind = (sub(sn).gameLength == GL(h)) & inStim;
            for t = 1:nTrial
                p_highInfo(h, t, c, sn) = nanmean(sub(sn).hi(ind,t+4));
                p_lowMean(h, t, c, sn)  = nanmean(sub(sn).lm(ind,t+4));
                p_repeat(h, t, c, sn)   = nanmean(sub(sn).rep(ind,t+4));
            end
        end
    end
end

% standard error of the mean along dimension d, using the number of
% non-NaN entries as the sample size
semAlong = @(A, d) nanstd(A, [], d) ./ sqrt(sum(~isnan(A), d));

figure(1); clf;
set(gcf, 'position', [588   576   600   500]);
ax = easy_gridOfEqualFigures([0.1 0.18 0.1], [0.16 0.18  0.02]);

% --- panels 1 and 2: raw probabilities ---------------------------------
% e is built up as [4 x 2]: rows 1/2 = vertex/RFPC in panel 1,
% rows 3/4 = vertex/RFPC in panel 2; columns = horizon 1 / horizon 6
axes(ax(1)); hold on;
xlim([0.5 6.5])
M = nanmean(p_highInfo,4);
S = semAlong(p_highInfo,4);
e = errorbar(M(:,:,1)', S(:,:,1)');
e(2,:) = errorbar(M(:,:,2)', S(:,:,2)');
xlabel('trial number')
ylabel('p(high info)')
title('"directed"')
% e(:) runs down the columns: vertex/h1, RFPC/h1, vertex/h6, RFPC/h6
leg = legend(e(:), {'vertex, horizon 1' 'RFPC, horizon 1' 'vertex, horizon 6' 'RFPC, horizon 6'});
set(leg, 'position', [0.2658    0.7610    0.2425    0.1150])

axes(ax(2)); hold on;
xlim([0.5 6.5])
M = nanmean(p_lowMean,4);
S = semAlong(p_lowMean,4);
e(3,:) = errorbar(M(:,:,1)', S(:,:,1)');
e(4,:) = errorbar(M(:,:,2)', S(:,:,2)');
xlabel('trial number')
ylabel('p(low mean)')
title('"random"')

set(e([1 3],1), 'color', AZblue)
set(e([2 4],1), 'color', AZblue*0.5+0.5*[1 1 1])
set(e([1 3],2), 'color', AZred)
set(e([2 4],2), 'color', AZred*0.5+0.5*[1 1 1])
% RFPC curves are dashed in both horizons (previously only horizon 6 was
% dashed, which was inconsistent with the per-condition markers below)
set(e([2 4],:), 'linestyle', '--')
set(e([1 3],:), 'marker', '+')
set(e([2 4],:), 'marker', 'x')
set(e, 'markersize', 10);%, 'marker', '.')
set(ax(1:2), 'ylim', [0.1 0.6], 'ytick', [0:0.2:1], 'xlim', [0.5 6.5], 'xtick', [1:6])

set(ax, 'tickdir', 'out')
addABCs(ax, [-0.1 0.05], 32)

% --- panels 3 and 4: RFPC minus vertex ---------------------------------
% plot difference; d_hi / d_lo are [game length x trial x subject]
d_hi = squeeze(p_highInfo(:,:,2,:)-p_highInfo(:,:,1,:));
d_lo = squeeze(p_lowMean(:,:,2,:)-p_lowMean(:,:,1,:));


axes(ax(3)); hold on;
M = nanmean(d_hi,3);
S = semAlong(d_hi,3);
plot([0 7], [0 0], 'k--', 'linewidth', 1)
e = errorbar(M', S');
legend(e, {'horizon 1' 'horizon 6'}, 'location', 'southeast')
xlabel('trial number')
ylabel({'\Deltap(high info)' 'RFPC - vertex'})

axes(ax(4)); hold on;
M = nanmean(d_lo,3);
S = semAlong(d_lo,3);
plot([0 7], [0 0], 'k--', 'linewidth', 1)
e(2,:) = errorbar(M', S');
xlabel('trial number')
ylabel({'\Deltap(low mean)' 'RFPC - vertex'})

% one-sample t-tests (per trial) of the RFPC - vertex difference in the
% second game-length condition; results are printed on purpose and kept
% in separate variables so the second test does not overwrite the first
[~,p_dHighInfo] = ttest(squeeze(p_highInfo(2,:,2,:)-p_highInfo(2,:,1,:))')
[~,p_dLowMean] = ttest(squeeze(p_lowMean(2,:,2,:)-p_lowMean(2,:,1,:))')
set(ax, 'xlim', [0.5 6.5], 'xtick', [1:6], 'tickdir', 'out')
set(e(:,1), 'color', AZblue)
set(e(:,2), 'color', AZred)
set(e, 'markersize', 30, 'marker', '.')
set(ax(3:4), 'ylim', [-0.1 0.05], 'xlim', [0.5 6.5])


%% ========================================================================
%% HIERARCHICAL MODEL FIT ON FIRST FREE CHOICE %%
%% HIERARCHICAL MODEL FIT ON FIRST FREE CHOICE %%
%% HIERARCHICAL MODEL FIT ON FIRST FREE CHOICE %%
%% ========================================================================

%% prep data structure 
% Pack the per-subject data into rectangular [subject x game (x trial)]
% arrays and bundle them in datastruct for the model fit. Rows of subjects
% with fewer games than the array width keep their padding value; G holds
% the actual number of games of every subject.
clear a c5 r UC GL G dI TMS
NS = length(sub);   % number of subjects
T = 4;              % number of forced choices

% array width: at least 320 games (the original fixed size), more if some
% subject has more games
nGames = arrayfun(@(s) size(s.a, 1), sub);
maxG = max([320; nGames(:)]);

a   = zeros(NS, maxG, T);
c5  = nan(NS, maxG);
r   = zeros(NS, maxG, T);
UC  = nan(NS, maxG);
GL  = nan(NS, maxG);
dI  = zeros(NS, maxG);
TMS = zeros(NS, maxG);
G   = zeros(1, NS);

for sn = 1:length(sub)
    
    % choices on forced trials
    dum = sub(sn).a(:,1:4);
    a(sn,1:size(dum,1),:) = dum;
    
    % choices on free trial
    % note a slight hacky feel here - a is 1 or 2, c5 is 0 or 1.
    dum = sub(sn).a(:,5) == 2;
    c5(sn,1:size(dum,1)) = dum;
    
    % rewards
    dum = sub(sn).r(:,1:4);
    r(sn,1:size(dum,1),:) = dum;
    
    % game length
    dum = sub(sn).gameLength;
    GL(sn,1:size(dum,1)) = dum;
    
    % number of games for this subject
    G(sn) = length(dum);
    
    % uncertainty condition 
    dum = abs(sub(sn).uc - 2) + 1;
    UC(sn, 1:size(dum,1)) = dum;
    
    % difference in information
    dum = sub(sn).uc - 2;
    dI(sn, 1:size(dum,1)) = -dum;
    
    % TMS flag
    dum = strcmp(sub(sn).expt_name, 'RFPC');
    TMS(sn,1:size(dum,1)) = dum;
    
end

% number of distinct horizons / uncertainty conditions; the padding (NaN)
% must be removed first because unique() counts every NaN separately
dum = GL(:); dum(dum==0 | isnan(dum)) = [];
H = length(unique(dum));
dum = UC(:); dum(dum==0 | isnan(dum)) = [];
U = length(unique(dum));

% recode game length 5 -> 1 and 10 -> 2
GL(GL==5) = 1;
GL(GL==10) = 2;

% condition 1 combines game length and uncertainty condition,
% condition 2 is the stimulation condition (1 = not RFPC, 2 = RFPC)
C1 = (GL-1)*2+UC;
C2 = TMS + 1;
nC1 = 4;
nC2 = 2;

% meaning of condition 1
% gl uc c1
%  1  1  1 - horizon 1, [2 2]
%  1  2  2 - horizon 1, [1 3]
%  2  1  3 - horizon 6, [2 2]
%  2  2  4 - horizon 6, [1 3]



datastruct = struct(...
    'C1', C1, 'C2', C2, 'nC1', nC1, 'nC2', nC2, ...
    'NS', NS, 'G',  G,  'T',   T, ...
    'dI', dI, 'a',  a,  'c5',  c5, 'r', r);

%% run hierarchical model fits! 

% 750s (12.5 minutes) for 1500 samples
cd(fundir)

% MCMC parameters for JAGS
nchains  = 4;
nburnin  = 500;
nsamples = 1000;
thin     = 1;

% Starting values for the latent variables; every chain is initialised
% with the same values
clear S init0
S.a0    = ones(1, nC2);
S.b0    = ones(1, nC2);
S.a_inf = ones(1, nC2);
S.b_inf = ones(1, nC2);
S.AA    = zeros(NS, nC1, nC2);
S.BB    = 100 * ones(NS, nC1, nC2);
for i = 1:nchains
    init0(i) = S;
end

% quantities recorded by the sampler
monitored = { ...
    'a0' 'b0' 'alpha_start' 'alpha0' 'alpha_d' ...
    'a_inf' 'b_inf' 'alpha_inf' ...
    'mu0_mean' 'mu0_sigma' 'mu0' ...
    'AA_mean' 'AA_sigma' 'AA' ...
    'SB_mean' 'SB_sigma' 'SB' ...
    'BB_mean' 'BB' ...
    };

% Use JAGS to Sample (timed with tic / toc)
tic

doparallel = 1;
% if doparallel
%     parpool;
% end
fprintf( 'Running JAGS\n' );
modelFile = fullfile(pwd, 'model_KFcond_v1');
[samples, stats ] = matjags( ...
    datastruct, ...
    modelFile, ...
    init0, ...
    'doparallel' , doparallel, ...
    'nchains', nchains,...
    'nburnin', nburnin,...
    'nsamples', nsamples, ...
    'thin', thin, ...
    'monitorparams', monitored, ...
    'savejagsoutput' , 1 , ...
    'verbosity' , 1 , ...
    'cleanup' , 0  );
toc

%% throw out first N samples
% N is the number of leading samples (second dimension of every samples.*
% array) that are discarded before the means are recomputed. The previous
% code indexed N:end, which discards only N-1 samples; N = 0 with
% (N+1):end keeps exactly the same samples as the old default (N = 1).
N = 0;
keepFrom = N + 1;

% mean over the retained part of dimension 2, then over dimension 1
postMean = @(x) squeeze(mean(mean(x(:, keepFrom:end, :, :, :), 2), 1));

stats.mean.SB = postMean(samples.SB);
stats.mean.BB = postMean(samples.BB);
stats.mean.AA = postMean(samples.AA);

stats.mean.mu0         = postMean(samples.mu0);
stats.mean.alpha_start = postMean(samples.alpha_start);
stats.mean.alpha_inf   = postMean(samples.alpha_inf);
stats.mean.alpha0      = postMean(samples.alpha0);
stats.mean.alpha_d     = postMean(samples.alpha_d);

%% Figure 4: Model-based analysis of the first free-choice trial 
% Left column of axes: histograms of the group-level samples for each
% parameter in the two stimulation conditions. Right column: histograms of
% the sample-wise difference (second condition minus first condition).
%
% X holds one column per (parameter, stimulation condition); odd columns
% are stimulation condition 1 (vertex), even columns condition 2 (RFPC):
%    1-2   mu0_mean
%    3-4   a0 / (a0 + b0)
%    5-6   a_inf / (a_inf + b_inf)
%    7-10  AA_mean, condition-1 index 2 then 4
%   11-14  BB_mean, condition-1 index 2 then 4
%   15-18  BB_mean, condition-1 index 1 then 3
%   19-22  SB_mean, condition-1 index 2 then 4
%   23-26  SB_mean, condition-1 index 1 then 3
clear X D dd l t tt bins PL0
col = @(A) reshape(A, [], 1);
X = nan(numel(samples.mu0_mean(:,:,1)), 26);
k = 0;

% mean of prior
for c2 = 1:2
    k = k + 1;
    X(:,k) = col(samples.mu0_mean(:,:,c2));
end

% alpha_0
for c2 = 1:2
    k = k + 1;
    X(:,k) = col(samples.a0(:,:,c2) ./ (samples.a0(:,:,c2) + samples.b0(:,:,c2)));
end

% alpha_inf
for c2 = 1:2
    k = k + 1;
    X(:,k) = col(samples.a_inf(:,:,c2) ./ (samples.a_inf(:,:,c2) + samples.b_inf(:,:,c2)));
end

% info bonus, noise [1 3], noise [2 2], spatial bias [1 3], spatial bias [2 2]
groups = { ...
    'AA_mean', [2 4]; ...
    'BB_mean', [2 4]; ...
    'BB_mean', [1 3]; ...
    'SB_mean', [2 4]; ...
    'SB_mean', [1 3]};
for g = 1:size(groups, 1)
    P = samples.(groups{g,1});
    for c1 = groups{g,2}
        for c2 = 1:2
            k = k + 1;
            X(:,k) = col(P(:,:,c1,c2));
        end
    end
end

% difference between the two stimulation conditions (RFPC - vertex)
D = X(:,2:2:end) - X(:,1:2:end);

% histogram bin centres for the per-condition panels
meanBins = [ ...
    {[0:100]}, ...
    repmat({[0:0.01:1]},    1, 2), ...
    repmat({[-10:0.1:10]},  1, 2), ...
    repmat({[0:0.1:20]},    1, 4), ...
    repmat({[-10:0.1:10]},  1, 4)];

% histogram bin centres for the difference panels
% note zero must be in middle of range
diffBins = [ ...
    {[-40:0.5:40]}, ...
    repmat({[-0.5:0.01:0.5]}, 1, 2), ...
    repmat({[-12:0.1:12]},    1, 10)];

% row labels (one per parameter)
vName = { ...
    {'prior mean, \mu_0'}, ...
    {'initial learning rate, \alpha_0'}, ...
    {'asymptotic learning rate, \alpha_{inf}'}, ...
    {'information bonus (horizon 1), A'}, ...
    {'information bonus (horizon 6), A'}, ...
    {'decision noise (horizon 1, [1 3]), \sigma'}, ...
    {'decision noise (horizon 6, [1 3]), \sigma'}, ...
    {'decision noise (horizon 1, [2 2]), \sigma'}, ...
    {'decision noise (horizon 6, [2 2]), \sigma'}, ...
    {'spatial bias (horizon 1, [1 3]), B'}, ...
    {'spatial bias (horizon 6, [1 3]), B'}, ...
    {'spatial bias (horizon 1, [2 2]), B'}, ...
    {'spatial bias (horizon 6, [2 2]), B'}};

figure(1); clf;
set(gcf, 'position', [211   137   700   650])
dw = 0.02;
DW = 0.08;
[~,~,~,ax] = easy_gridOfEqualFigures([0.1 dw dw dw DW dw dw dw DW dw DW dw DW 0.1], [0.45 0.1 0.03]);

% --- left column: one histogram per stimulation condition ---------------
clear l t
for count = 1:size(D, 2)
    i = 2*count - 1;
    axes(ax(count,1)); hold on;
    [h,x] = hist(X(:,i), meanBins{count});
    l(i) = plot(x, h/max(h), 'k-', 'linewidth', 2);
    [h,x] = hist(X(:,i+1), meanBins{count});
    l(i+1) = plot(x, h/max(h), 'k-', 'linewidth', 2);
    t(count) = text(0, 0, vName{count}, 'fontsize', 18);
    xlim([min(x) max(x)])
end

set(ax, 'tickdir', 'out', 'yticklabel', [])
set(ax([2 4 6 7 8 10 11 12],:), 'xticklabel', [])
set(l(1:2:end), 'color', [1 0.5 0])
set(l(2:2:end), 'color', [0 0.5 1])

xlabel({'parameter value'});
set(t, 'units', 'normalized', 'position', [-0.05 1], ...
    'horizontalalignment', 'right', ...
    'verticalalignment', 'top')

% --- right column: histogram of the difference --------------------------
clear l t
PL0 = nan(1, size(D, 2));   % fraction of samples with a difference below 0
for i = 1:size(D, 2)
    axes(ax(i,2)); hold on;
    [h,x] = hist(D(:,i), diffBins{i});
    plot([0 0], [0 1], '-', 'linewidth', 1, 'color', [1 1 1]*0.75)
    l(i) = plot(x, h/max(h), 'k-', 'linewidth', 2);
    PL0(i) = mean(D(:,i)<0);
    xlim([min(x) max(x)])
    
end

xlabel({'parameter change' });

% column titles and legend
axes(ax(1)); 
tt = text(0, 0, 'mean', 'fontsize', 18);
axes(ax(1,2)); 
tt(2) = text(0, 0, {'difference' '(RFPC - vertex)'}, 'fontsize', 18);
set(tt, 'units', 'normalized', 'position', [0.5 3.7], ...
    'horizontalalignment', 'center', ...
    'verticalalignment', 'middle', 'fontweight', 'bold')
axes(ax(1));
leg = legend({'vertex' 'RFPC'}, 'orientation', 'horizontal');

set(leg, 'position', [0.4454    0.9177    0.2232    0.0269])

%% Figure 5: Horizon-dependent changes in directed exploration are reduced by TMS of RFPC
% Scatter plots (every dn-th sample) of the RFPC - vertex change in the
% prior mean against the RFPC - vertex change in
%   (1) the horizon-1 information bonus,
%   (2) the horizon-6 information bonus,
%   (3) the horizon effect on the information bonus (2 minus 1).
% Relies on X as built in the Figure 4 section.
%
% The axis labels now state "RFPC - vertex", which is what the columns of
% Y1 actually contain (even column minus odd column of X, the same
% direction as the "difference (RFPC - vertex)" panels of Figure 4); the
% previous labels had the subtraction the other way round.
clear Y1 rho pval
dn = 2;     % plot / correlate every dn-th sample
figure(1); clf;
set(gcf, 'position', [594   524   835   300])
Y1(:,1) = X(:,2)-X(:,1);        % prior mean
Y1(:,2) = X(:,8)-X(:,7);        % information bonus, horizon 1
Y1(:,3) = X(:,10)-X(:,9);       % information bonus, horizon 6
Y1(:,4) = Y1(:,3) - Y1(:,2);    % horizon effect on information bonus

ax = easy_gridOfEqualFigures([0.25 0.1], [0.15 0.17 0.17 0.05]);

dotColor = {AZblue, AZred, [1 1 1]*0.4};
yLabels = { ...
    {'change in horizon 1' 'information bonus' 'A(RFPC) - A(vertex)'}, ...
    {'change in horizon 6' 'information bonus' 'A(RFPC) - A(vertex)'}, ...
    {'change in horizon effect' ' on information bonus' '\DeltaA(RFPC) - \DeltaA(vertex)'}};

for k = 1:3
    axes(ax(k)); hold on;
    x1 = Y1(1:dn:end, 1);
    x2 = Y1(1:dn:end, k+1);
    
    plot(x1, x2, '.', ...
        'markersize', 5, ...
        'linewidth', 1, 'color', dotColor{k});
    plot([-20 40], [0 0], '--', 'color', [1 1 1]*0, 'linewidth', 1)
    
    % correlation, rounded to two decimals for display; stored under its
    % own name so the rewards array r is not overwritten
    [rho(k), pval(k)] = corr(x1, x2);
    rho(k) = round(rho(k)*100)/100;
    text(-18, -15, ['r = ' num2str(rho(k))], 'fontsize', 18);
    
    xlabel({'change in prior mean' 'R_0(RFPC) - R_0(vertex)'})
    ylabel(yLabels{k})
end

set(ax, 'xlim', [-20 40], 'ylim', [-16 6], 'tickdir', 'out')
addABCs(ax, [-0.13 0.05], 28)

